add weight_sum_fp32 config #2519
base: main
Conversation
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This pull request introduces a new configuration flag, float32_weight_sum, to control the precision of the weighted sum operation in the Mixture of Experts (MoE) layers. The changes are well-implemented and provide useful flexibility for balancing performance and numerical precision.
🔍 General Feedback
- The addition of the float32_weight_sum flag is a good feature for optimizing MoE layers.
- The implementation in src/MaxText/layers/moe.py correctly applies the conditional casting based on the new configuration (a sketch of what such a cast looks like follows this list).
- A minor style suggestion was made to improve comment consistency.
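A minimal sketch of the kind of conditional cast described above, for readers skimming the thread. The function name, tensor shapes, and einsum spec are illustrative assumptions rather than the actual code in src/MaxText/layers/moe.py; only the gating on a float32_weight_sum-style flag mirrors what the PR adds.

```python
import jax.numpy as jnp

def weighted_sum_unpermute(expert_outputs, routing_weights, float32_weight_sum=True):
  """Combine per-expert outputs with routing weights during the final unpermute.

  expert_outputs:  [tokens, experts_per_token, model_dim], typically bfloat16.
  routing_weights: [tokens, experts_per_token].
  """
  if float32_weight_sum:
    # Accumulate the weighted sum in fp32 for numerical stability,
    # then cast back to the activation dtype.
    out = jnp.einsum(
        "ted,te->td",
        expert_outputs.astype(jnp.float32),
        routing_weights.astype(jnp.float32),
    )
    return out.astype(expert_outputs.dtype)
  # Keep everything in the activation dtype: faster, slightly lower precision.
  return jnp.einsum("ted,te->td", expert_outputs, routing_weights)
```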
```yaml
cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher precision is generally beneficial, but it can vary slightly.
float32_qk_product: False # in dot_product attention, whether to cast to fp32 the inputs to qk product
float32_logits: False # in dot_product attention, whether to cast to fp32 the inputs to softmax
float32_weight_sum: True # whether to use full fp32 precision for weight_sum during final unpermute in moe
```
🟢 Nit: For consistency and clarity, it's better to use "MoE" instead of "moe".
Suggested change:

```diff
- float32_weight_sum: True # whether to use full fp32 precision for weight_sum during final unpermute in moe
+ float32_weight_sum: True # whether to use full fp32 precision for weight_sum during final unpermute in MoE
```
Commits: update · update · Update base.yml · Update moe.py · Update moe.py
Force-pushed from 6f0349d to ee4e8cc
Description
Add config flag `weight_sum_fp32` to control whether to use full fp32 precision for weight_sum during the final unpermute in MoE.
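As a rough intuition for why such a flag can matter (a toy example, not part of the PR; the values are arbitrary): accumulating many small bfloat16 contributions can lose low-order bits that an fp32 accumulation preserves.

```python
import jax.numpy as jnp

# Toy per-token expert contributions and routing weights in bfloat16.
outputs = jnp.linspace(0.01, 0.02, 8, dtype=jnp.bfloat16)   # 8 expert contributions
weights = jnp.full((8,), 0.125, dtype=jnp.bfloat16)          # uniform routing weights

bf16_sum = jnp.sum(outputs * weights)                         # weighted sum in bfloat16
fp32_sum = jnp.sum(outputs.astype(jnp.float32) * weights.astype(jnp.float32))

# The two results may differ in the low-order bits; the new flag trades this
# extra precision against the cost of fp32 accumulation.
print(bf16_sum, fp32_sum)
```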
Tests
Final eval loss at 300 steps:
- 2.394 ([cloudlog](https://cloudlogging.app.goo.gl/Q5o2tac9aypGGMyV6))
- 2.393 ([cloudlog](https://cloudlogging.app.goo.gl/L2N43dAZiHap1Djk7))
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.